{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import feast\n",
    "\n",
    "# Read the expected version once; without it the comparison below cannot be meaningful.\n",
    "expected_version = os.environ.get(\"FEAST_VERSION\")\n",
    "assert expected_version, (\n",
    "    \"❌ FEAST_VERSION environment variable is not set; cannot verify the installed Feast version\"\n",
    ")\n",
    "\n",
    "# Compare the installed package version against the expected one.\n",
    "actual_version = feast.__version__\n",
    "assert actual_version == expected_version, (\n",
    "    f\"❌ Feast version mismatch. Expected: {expected_version}, Found: {actual_version}\"\n",
    ")\n",
    "print(f\"✅ Successfully installed Feast version: {actual_version}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd /opt/app-root/src/feature_repo\n",
    "!ls -l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat /opt/app-root/src/feature_repo/feature_store.yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data\n",
    "!wget -O data/city_wikipedia_summaries_with_embeddings.parquet https://raw.githubusercontent.com/opendatahub-io/feast/master/examples/rag/feature_repo/data/city_wikipedia_summaries_with_embeddings.parquet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_parquet(\"./data/city_wikipedia_summaries_with_embeddings.parquet\")\n",
    "# Replace each stored embedding with the result of its .tolist() call.\n",
    "df['vector'] = df['vector'].apply(lambda x: x.tolist())\n",
    "\n",
    "# Positional access: df['vector'][0] is a label lookup and raises KeyError\n",
    "# when the frame's index does not contain the label 0.\n",
    "embedding_length = len(df['vector'].iloc[0])\n",
    "assert embedding_length == 384, f\"❌ Expected vector length 384, but got {embedding_length}\"\n",
    "print(f'embedding length = {embedding_length}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel (a bare `!pip` may resolve to\n",
    "# a different interpreter); quoting keeps the shell from globbing the extras brackets.\n",
    "%pip install -q \"pymilvus[milvus_lite]\" transformers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "# Run `feast apply` and capture output\n",
    "result = subprocess.run([\"feast\", \"apply\"], capture_output=True, text=True)\n",
    "\n",
    "# Combine stdout and stderr in case important info is in either\n",
    "output = result.stdout + result.stderr\n",
    "\n",
    "# Print full output for debugging (printed before any assertion so failures stay diagnosable)\n",
    "print(output)\n",
    "\n",
    "# Surface a failing CLI run explicitly instead of relying only on the message checks below\n",
    "assert result.returncode == 0, f\"❌ `feast apply` exited with code {result.returncode}\"\n",
    "\n",
    "# Expected substrings to validate\n",
    "expected_messages = [\n",
    "    \"Applying changes for project rag\",\n",
    "    \"Connecting to Milvus in local mode\",\n",
    "    \"Deploying infrastructure for city_embeddings\"\n",
    "]\n",
    "\n",
    "# Validate all expected messages are in output\n",
    "for msg in expected_messages:\n",
    "    assert msg in output, f\"❌ Expected message not found: '{msg}'\"\n",
    "\n",
    "print(\"✅ All expected messages were found in the output.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "from feast import FeatureStore\n",
    "\n",
    "store = FeatureStore(repo_path=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import contextlib\n",
    "import io\n",
    "import sys\n",
    "\n",
    "# Capture stdout while writing to the online store. redirect_stdout restores\n",
    "# sys.stdout when the block exits, including when the call raises, so a failure\n",
    "# here cannot leave the rest of the notebook with its output redirected.\n",
    "captured_output = io.StringIO()\n",
    "with contextlib.redirect_stdout(captured_output):\n",
    "    store.write_to_online_store(feature_view_name='city_embeddings', df=df)\n",
    "\n",
    "# Get the output\n",
    "output_str = captured_output.getvalue()\n",
    "\n",
    "# Expected message\n",
    "expected_msg = \"Connecting to Milvus in local mode using data/online_store.db\"\n",
    "\n",
    "# Validate\n",
    "assert expected_msg in output_str, f\"❌ Expected message not found.\\nExpected: {expected_msg}\\nActual Output:\\n{output_str}\"\n",
    "\n",
    "print(\"✅ Output message validated successfully.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List batch feature views\n",
    "batch_fvs = store.list_batch_feature_views()\n",
    "\n",
    "# Print the number of batch feature views\n",
    "print(\"Number of batch feature views:\", len(batch_fvs))\n",
    "\n",
    "# len() is always a non-negative int, so checking that can never fail.\n",
    "# Instead require at least one batch feature view, since `feast apply`\n",
    "# was run earlier in this notebook.\n",
    "assert len(batch_fvs) >= 1, \"No batch feature views were returned\"\n",
    "\n",
    "print(\"Feature views listed correctly ✅\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from feast import FeatureStore\n",
    "\n",
    "# (Re)create the store handle for the repo in the current directory\n",
    "store = FeatureStore(repo_path=\".\")\n",
    "\n",
    "# Look up the feature view under test\n",
    "fv = store.get_feature_view(\"city_embeddings\")\n",
    "\n",
    "# Name and entity checks\n",
    "assert fv.name == \"city_embeddings\", \"Feature view name mismatch\"\n",
    "assert fv.entities == [\"item_id\"], f\"Expected entities ['item_id'], got {fv.entities}\"\n",
    "\n",
    "# Every required feature must be present, checked in a fixed order\n",
    "feature_names = [f.name for f in fv.features]\n",
    "for required_name in (\"vector\", \"state\", \"sentence_chunks\", \"wiki_summary\"):\n",
    "    assert required_name in feature_names, f\"Missing '{required_name}' feature\"\n",
    "\n",
    "# The 'vector' feature must be a vector index using the COSINE metric\n",
    "vector_feature = next(f for f in fv.features if f.name == \"vector\")\n",
    "assert vector_feature.vector_index, \"'vector' feature is not indexed\"\n",
    "assert vector_feature.vector_search_metric == \"COSINE\", \"Expected COSINE search metric for 'vector'\"\n",
    "\n",
    "print(\"All assertions passed ✅\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from feast.entity import Entity\n",
    "from feast.types import ValueType\n",
    "entity = Entity(\n",
    "    name=\"item_id1\",\n",
    "    value_type=ValueType.INT64,\n",
    "    description=\"test id\",\n",
    "    tags={\"team\": \"feast\"},\n",
    ")\n",
    "store.apply(entity)\n",
    "assert any(e.name == \"item_id1\" for e in store.list_entities())\n",
    "print(\"Entity added ✅\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_to_delete = store.get_entity(\"item_id1\")\n",
    "\n",
    "store.apply(\n",
    "    objects=[],\n",
    "    objects_to_delete=[entity_to_delete],\n",
    "    partial=False\n",
    ")\n",
    "\n",
    "# Validation after deletion\n",
    "assert not any(e.name == \"item_id1\" for e in store.list_entities())\n",
    "print(\"Entity deleted ✅\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List batch feature views\n",
    "batch_fvs = store.list_batch_feature_views()\n",
    "assert len(batch_fvs) == 1\n",
    "\n",
    "# Print count\n",
    "print(f\"Found {len(batch_fvs)} batch feature view(s) ✅\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this reaches into private Feast attributes to obtain the client used by the online store.\n",
    "pymilvus_client = store._provider._online_store._connect(store.config)\n",
    "\n",
    "# Fail with a readable message rather than a bare IndexError when nothing was created\n",
    "collection_names = pymilvus_client.list_collections()\n",
    "assert collection_names, \"❌ No Milvus collections found in the online store\"\n",
    "COLLECTION_NAME = collection_names[0]\n",
    "\n",
    "milvus_query_result = pymilvus_client.query(\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    filter=\"item_id == '0'\",\n",
    ")\n",
    "assert len(milvus_query_result) > 0, \"❌ Milvus query returned no rows for item_id == '0'\"\n",
    "\n",
    "# Show the first matching row\n",
    "pd.DataFrame(milvus_query_result[0]).head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from feast import FeatureStore\n",
    "from pymilvus import MilvusClient, DataType, FieldSchema\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "from example_repo import city_embeddings_feature_view, item\n",
    "\n",
    "TOKENIZER = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "MODEL = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "\n",
    "def mean_pooling(model_output, attention_mask):\n",
    "    token_embeddings = model_output[\n",
    "        0\n",
    "    ]  # First element of model_output contains all token embeddings\n",
    "    input_mask_expanded = (\n",
    "        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
    "    )\n",
    "    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(\n",
    "        input_mask_expanded.sum(1), min=1e-9\n",
    "    )\n",
    "\n",
    "def run_model(sentences, tokenizer, model):\n",
    "    encoded_input = tokenizer(\n",
    "        sentences, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    # Compute token embeddings\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded_input)\n",
    "\n",
    "    sentence_embeddings = mean_pooling(model_output, encoded_input[\"attention_mask\"])\n",
    "    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
    "    return sentence_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"Which city has the largest population in New York?\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)\n",
    "model = AutoModel.from_pretrained(MODEL)\n",
    "query_embedding = run_model(question, tokenizer, model)\n",
    "query = query_embedding.detach().cpu().numpy().tolist()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "# Retrieve top k documents\n",
    "context_data = store.retrieve_online_documents_v2(\n",
    "    features=[\n",
    "        \"city_embeddings:vector\",\n",
    "        \"city_embeddings:item_id\",\n",
    "        \"city_embeddings:state\",\n",
    "        \"city_embeddings:sentence_chunks\",\n",
    "        \"city_embeddings:wiki_summary\",\n",
    "    ],\n",
    "    query=query,\n",
    "    top_k=3,\n",
    "    distance_metric='COSINE',\n",
    ").to_df()\n",
    "display(context_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_documents(context_df):\n",
    "    output_context = \"\"\n",
    "    unique_documents = context_df.drop_duplicates().apply(\n",
    "        lambda x: \"City & State = {\" + x['state'] +\"}\\nSummary = {\" + x['wiki_summary'].strip()+\"}\",\n",
    "        axis=1,\n",
    "    )\n",
    "    for i, document_text in enumerate(unique_documents):\n",
    "        output_context+= f\"****START DOCUMENT {i}****\\n{document_text.strip()}\\n****END DOCUMENT {i}****\"\n",
    "    return output_context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "RAG_CONTEXT = format_documents(context_data[['state', 'wiki_summary']])\n",
    "print(RAG_CONTEXT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FULL_PROMPT = f\"\"\"\n",
    "You are an assistant for answering questions about states. You will be provided documentation from Wikipedia. Provide a conversational answer.\n",
    "If you don't know the answer, just say \"I do not know.\" Don't make up an answer.\n",
    "\n",
    "Here are document(s) you should use when answer the users question:\n",
    "{RAG_CONTEXT}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel; a bare `!pip` may resolve to a different interpreter.\n",
    "%pip install openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from openai import OpenAI\n",
    "\n",
    "client = OpenAI(\n",
    "    api_key=os.environ.get(\"OPENAI_API_KEY\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"gpt-4o-mini\",\n",
    "    messages=[\n",
    "        {\"role\": \"system\", \"content\": FULL_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": question}\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The expected output\n",
    "expected_output = \"New York City\"\n",
    "\n",
    "# Actual output from response. A choice's message content may be None, so fall back\n",
    "# to an empty string and let the assertion below report the mismatch instead of\n",
    "# failing with an AttributeError on .strip().\n",
    "actual_output = '\\n'.join([(c.message.content or \"\").strip() for c in response.choices])\n",
    "\n",
    "# Validate\n",
    "assert expected_output in actual_output, f\"❌ Output mismatch:\\nExpected: {expected_output}\\nActual: {actual_output}\"\n",
    "\n",
    "print(\"✅ Output matches expected response.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
